#include <iostream>
#include <torch/script.h>
#include "hcn.h"

namespace myHcn {

using namespace std;
using namespace torch;

Hcn::Hcn(const string &modelFile)
{
    model = jit::load(modelFile);
    isRunning = false;
}

bool Hcn::setHc()
{
    if (isRunning) {
        hc = output.toTuple()->elements()[1].toTensor();
    } else {
        hc = torch::zeros({4,1,1,96});
    }
    return true;
}

Tensor Hcn::getHc()
{
    return hc;
}

Tensor Hcn::run(float tfValue)
{
    setHc();
    isRunning = true;
    // Tensor input1 = tensor({tfValue});
    // inputs.clear();
    // inputs.push_back(input1);
    // inputs.push_back(hc);
    
    inputs = {tensor({tfValue}),hc};
    output = model.forward(inputs);

    return output.toTuple()->elements()[0].toTensor();
}

int Hcn::shotOverCallback()
{
    isRunning = false;
    return 0;
}

}   // namespace myHcn
